www.gusucode.com > 有监督的 CNN 网络完成对MNIST 数字的识别 > 有监督的 CNN 网络完成对MNIST 数字的识别/CNN—卷积神经网络数字识别/@cnn/cutrain.m
function [out_cnet] = cutrain(cnet,Ip,labels,I_testp, labels_test) %TRAIN train convolutional neural network using stochastic Levenberg-Marquardt % % Syntax % % [cnet, perf_plot] = train(cnet,Ip,labels,I_testp, labtst) % % Description % Input: % cnet - Convolutional neural network class object % Ip - cell array, containing preprocessed images of handwriten digits % labels - cell array of labels, corresponding to images % I_testp - cell array, containing preprocessed images of handwriten % digits of test set % labtst - cell array of labels, corresponding to images of test set % Output: % cnet - trained convolutional neural network % perf_plot - performance data %Initialize GUI h_gui = cnn_gui(); %Progress bars h_HessPatch = findobj(h_gui,'Tag','HessPatch'); h_HessEdit = findobj(h_gui,'Tag','HessPrEdit'); h_TrainPatch = findobj(h_gui,'Tag','TrainPatch'); h_TrainEdit = findobj(h_gui,'Tag','TrainPrEdit'); %Axes h_MCRaxes = findobj(h_gui,'Tag','MCRaxes'); h_RMSEaxes = findobj(h_gui,'Tag','RMSEaxes'); %Info textboxes h_EpEdit = findobj(h_gui,'Tag','EpEdit'); h_ItEdit = findobj(h_gui,'Tag','ItEdit'); h_RMSEedit = findobj(h_gui,'Tag','RMSEedit'); h_MCRedit = findobj(h_gui,'Tag','MCRedit'); h_TetaEdit = findobj(h_gui,'Tag','TetaEdit'); %Buttons h_AbortButton = findobj(h_gui,'Tag','AbortButton'); tic; %Fix the start time perf_plot = []; %Array for storing performance data %Init the cuda cnn singnet = cnn2singlestruct(cnet); cudacnn('init',singnet); %Initial MCR calculation mcr(1)=cucalcMCR(I_testp, labels_test, 1:50); plot(h_MCRaxes,mcr); SetText(h_MCRedit,mcr(end)); numPats = length(Ip); if(cnet.HcalcMode == 1) for i=1:cnet.HrecalcSamplesNum %Setting the right output to 1, others to -1 d = -ones(1,10); d(labels(i)+1) = 1; %Simulating [out, cnet] = sim(cnet,Ip{i}); %Calculate the error e = out-d; %Calculate Jacobian times error, or in other words calculate %gradient [cnet,je] = calcje(cnet,e); [cnet,hx] = calchx(cnet); % jj = jj+diag(sparse(hx)); SetHessianProgress(h_HessPatch,h_HessEdit,i/cnet.HrecalcSamplesNum); end % %Averaging % jj = jj/cnet.HrecalcSamplesNum; end %For all epochs for t=1:cnet.epochs SetText(h_EpEdit,t); SetTextHP(h_TetaEdit,cnet.teta); %For all patterns for n=1:numPats %Setting the right output to 1, others to -1 d = -ones(1,10); d(labels(n)+1) = 1; %Simulating out = cudacnn('sim',single(Ip{n}))'; %Calculate the error e = out-d; %Calculate Jacobian times error, or in other words calculate %gradient cudacnn('adapt',single(e)); %Calculate Hessian diagonal approximation % if(cnet.HcalcMode == 0) % [cnet,hx] = calchx(cnet); % %Calculate the running estimate of Hessian diagonal approximation % jj = gamma*diag(sparse(hx))+sparse((1-gamma)*jj); % else % if(mod(t*numPats+n,cnet.Hrecalc)==0) %If it is time to recalculate Hessian % if(n+cnet.HrecalcSamplesNum>numPats) % stInd = numPats-cnet.HrecalcSamplesNum; % else % stInd = n; % end % for i=stInd:stInd+cnet.HrecalcSamplesNum % %Setting the right output to 1, others to -1 % d = -ones(1,10); % d(labels(i)+1) = 1; % %Simulating % [out, cnet] = sim(cnet,Ip{i}); % %Calculate the error % e = out-d; % %Calculate Jacobian times error, or in other words calculate % %gradient % [cnet,je] = calcje(cnet,e); % [cnet,hx] = calchx(cnet); % jj = jj+diag(sparse(hx)); % % SetHessianProgress(h_HessPatch,h_HessEdit,(i-stInd)/cnet.HrecalcSamplesNum); % end % %Averaging % jj = jj/cnet.HrecalcSamplesNum; % end % end %The following is usefull for debugging. %===========DEBUG % tmp(1)=check_finit_dif(cnet,1,Ip{n},d,1); %===========DEBUG perf(n) = mse(double(e)); %Store the error %Uncoment this if you want a gradient descent %dW = cnet.teta(t)*je; %Actually Levenberg-Marquardt %dW = (jj+cnet.mu*ii)\(cnet.teta*je); %Apply calculated weight updates %cnet = adapt_dw(cnet,dW); %Plot mean of performance for every 10 patterns % if(n>500) if(n>1) if(~mod(n-1,400)) mcr = [mcr cucalcMCR(I_testp, labels_test, 1:200)]; plot(h_MCRaxes,mcr); SetText(h_MCRedit,mcr(end)); end if(~mod(n-1,10)) perf_plot = [perf_plot,mean(sqrt(perf(n-10:n)))]; plot(h_RMSEaxes,perf_plot); SetText(h_RMSEedit,perf_plot(end)); end end % end; %It looks smoother, but not suit to large datasets % perf_plot = [perf_plot,mean(perf)]; % plot(perf_plot); % grid on; SetTrainingProgress(h_TrainPatch,h_TrainEdit,(n+(t-1)*numPats)/(numPats*cnet.epochs)); SetText(h_ItEdit,n); drawnow; if(~isempty(get(h_AbortButton,'UserData'))) fprintf('Training aborted \n'); out_cnet = singlestruct2cnn(cudacnn('save',singnet)); return; end end cnet.teta = cnet.teta*cnet.teta_dec; cudacnn('set_data',0,'teta',single(cnet.teta)); end %display('Train is finished'); %display('Time was '); toc %display('Perfomance is '); %mean(perf) %Save new weights back to cnet out_cnet = singlestruct2cnn(cudacnn('save',singnet)); %Sets Hessian progress %hp - handle of patch %hs - handle of editbox %pr - value from 0 to 1 function SetHessianProgress(hp,hs,pr) xpatch = [0 pr*100 pr*100 0]; set(hp,'XData',xpatch); set(hs,'String',[num2str(pr*100,'%5.2f'),'%']); drawnow; %Sets Training progress %hp - handle of patch %hs - handle of editbox %pr - value from 0 to 1 function SetTrainingProgress(hp,hs,pr) xpatch = [0 pr*100 pr*100 0]; set(hp,'XData',xpatch); set(hs,'String',[num2str(pr*100,'%5.2f'),'%']); %Set numeric text in the specified edit box %hs - handle of textbox %num - number to convert and set function SetText(hs,num) set(hs,'String',num2str(num,'%5.2f')); %Set numeric text in the specified edit box with high preceition %hs - handle of textbox %num - number to convert and set function SetTextHP(hs,num) set(hs,'String',num2str(num,'%5.3e'));